# Copyright 2020 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Makes a prediction using the AutoML Tables model specified in the config.
Assumes that a model with the specified display name has already been trained,
and that the features for prediction have been generated. Prediction
jobs typically take 5-25 minutes to complete, based on the volume of data.
See https://cloud.google.com/automl-tables/docs/predict-batch for details.
"""

import logging
import sys

import google.api_core.exceptions
from google.cloud import automl_v1beta1 as automl
from google.cloud import bigquery

import utils

# Set the root logger's level to DEBUG. This runs at import time, before
# main() is called.
logging.basicConfig(level=logging.DEBUG)


def main():
  """Runs an AutoML Tables batch prediction and collects its output.

  Every read and write location is taken from the 'global' section of the
  configuration file whose path is supplied on the command line.

  Steps, in order:
    1. Start the batch prediction on the AutoML service and block until it
       has finished.
    2. Copy the "predictions" table out of the dataset AutoML generated and
       into the configured destination dataset.
    3. If any rows failed, copy the "errors" table too and log a warning
       with the number of failed rows.
    4. Attempt to delete the AutoML-generated dataset; a Forbidden response
       is logged as a warning instead of being raised.
  """
  args = utils.parse_arguments(sys.argv)
  global_config = utils.read_config(args.config_path)['global']
  project_id = global_config['destination_project_id']

  def destination_table_path(table_name_key):
    """Returns 'project.dataset.table' for the table named by a config key."""
    return '{}.{}.{}'.format(
        project_id,
        global_config['destination_dataset'],
        global_config[table_name_key])

  # No credentials are passed explicitly to either client, so each library
  # resolves them on its own.
  tables_client = automl.TablesClient(
      project=project_id,
      region=global_config['automl_compute_region'])
  bigquery_client = bigquery.Client(project=project_id)

  # The input URI is the full path of the feature table to score. The output
  # URI names only the project, because AutoML picks the output dataset.
  # See https://cloud.google.com/automl-tables/docs/predict-batch for details.
  features_uri = 'bq://{}.{}.{}'.format(
      project_id,
      global_config['destination_dataset'],
      global_config['features_predict_table'])
  output_uri = 'bq://{}'.format(project_id)
  prediction_job = tables_client.batch_predict(
      bigquery_input_uri=features_uri,
      bigquery_output_uri=output_uri,
      model_display_name=global_config['model_display_name'])

  # Batch prediction is a long running operation; .result() waits
  # synchronously for it to complete before anything else happens.
  prediction_job.result()

  # The operation metadata reports the generated dataset as
  # "bq://project_id.output_dataset"; drop the scheme to get an id that the
  # BigQuery client accepts. Inside that dataset AutoML writes successful rows
  # to "predictions" and rejected rows to "errors" (a row may be rejected, for
  # example, when a numeric column received a string).
  output_info = prediction_job.metadata.batch_predict_details.output_info
  output_dataset_id = output_info.bigquery_output_dataset.split('bq://')[-1]
  predictions = bigquery_client.get_table(output_dataset_id + '.predictions')
  failures = bigquery_client.get_table(output_dataset_id + '.errors')

  # Copy the predictions into the destination dataset. The copy is expected
  # to fail if the destination table already exists.
  copy_job = bigquery_client.copy_table(
      sources=predictions,
      destination=destination_table_path('predictions_table'))
  copy_job.result()

  # The errors table is only worth keeping when it actually has rows.
  failed_row_count = failures.num_rows
  if failed_row_count > 0:
    bigquery_client.copy_table(
        sources=failures,
        destination=destination_table_path('failed_predictions_table'),
    ).result()
    logging.warning("%d rows in the batch prediction job failed.",
                    failed_row_count)

  # Best-effort cleanup of the dataset AutoML created. Forbidden is raised
  # when the account lacks BQ Data Owner permissions; that is reported but
  # does not abort the run.
  try:
    bigquery_client.delete_dataset(output_dataset_id, delete_contents=True)
  except google.api_core.exceptions.Forbidden:
    logging.warning(
        "Failed to delete BQ dataset generated by AutoML batch prediction."
        " Requires BQ Data Owner permissions to delete.")


# Run the batch prediction only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
  main()
